Introduction to Reservoir Computing with ReservoirPy¶
Nathan Trouvain
Inria, IMN, LaBRI - Bordeaux, France
UCLA - November 14th 2023
Key concepts and features ¶
- Numpy, Scipy, and that's it!
- Efficient execution (distributed implementation)
- Online and offline learning rules
- Complex model architectures enabled
- Documentation: https://reservoirpy.readthedocs.io/en/latest/
- GitHub: https://github.com/reservoirpy/reservoirpy
General info¶
- Everything is NumPy (and more generally "standard" scientific Python)
- The first axis of arrays always represents time.
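This convention can be sketched with plain NumPy:

```python
import numpy as np

# A multivariate timeseries: 1000 timesteps of a 3-dimensional signal.
# By convention, time is always the first axis.
X = np.zeros((1000, 3))
print(X.shape[0])  # number of timesteps: 1000
print(X.shape[1])  # number of features: 3
```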
Timeseries prediction ¶
The Lorenz attractor
$$ \begin{split} \dot{x}(t) &= \sigma (y(t) - x(t)) \\ \dot{y}(t) &= \rho x(t) - y(t) - x(t)z(t) \\ \dot{z}(t) &= x(t)y(t) - \beta z(t) \end{split} $$
- describes convection movements in a fluid. Highly chaotic!
from reservoirpy.datasets import lorenz
timesteps = 3000
X = lorenz(timesteps, x0=[17.67, 12.93, 43.91])
plot_lorenz(X, 1000)
Knowing series value at timestep $t$:
- How can we predict $t+1$, $t+100$...?
- How can we predict $t+1$, $t+2$, $\dots$, $t+n$ ?
10-steps ahead forecasting¶
Predict $X(t + 10)$ knowing $X(t)$.
from reservoirpy.datasets import to_forecasting
x, y = to_forecasting(X, forecast=10)
X_train1, y_train1 = x[:2000], y[:2000]
X_test1, y_test1 = x[2000:], y[2000:]
plot_train_test(X_train1, y_train1, X_test1, y_test1)
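Under the hood, `to_forecasting` essentially shifts the series against itself so that each input is paired with the value `forecast` steps ahead. A minimal sketch (hypothetical helper name; the real function may offer additional options such as train/test splitting):

```python
import numpy as np

def to_forecasting_sketch(X, forecast=10):
    """Pair each input X[t] with the target X[t + forecast]."""
    return X[:-forecast], X[forecast:]

X_toy = np.arange(30).reshape(-1, 1)  # toy univariate series: 0, 1, ..., 29
x, y = to_forecasting_sketch(X_toy, forecast=10)
print(x[0, 0], y[0, 0])  # 0 10 : the target is the value 10 steps ahead
```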
A quick reservoir reminder¶
$$ x[t+1]= (1 - \alpha) x[t] + \alpha f(W \cdot x[t] + W_{in} \cdot u[t] + W_{fb} \cdot y[t]) $$ $$ y[t]= W_{out}^{\intercal} x[t] $$
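The update rule above can be sketched in plain NumPy (feedback term $W_{fb} \cdot y[t]$ omitted for simplicity; the dense random matrices here are placeholders, not ReservoirPy's actual initializers):

```python
import numpy as np

rng = np.random.default_rng(0)
n_units, n_inputs = 100, 3
alpha = 0.3  # leaking rate

W = rng.normal(size=(n_units, n_units)) * 0.1  # recurrent weights
W_in = rng.normal(size=(n_units, n_inputs))    # input weights

def update(x, u):
    """One leaky-integration step: blend previous state with new activation."""
    return (1 - alpha) * x + alpha * np.tanh(W @ x + W_in @ u)

x = np.zeros(n_units)          # initial state
u = rng.normal(size=n_inputs)  # one input timestep
x = update(x, u)
print(x.shape)  # (100,)
```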
ESN preparation¶
units = 100 # - number of units
leak_rate = 0.3 # - leaking rate
spectral_radius = 0.95 # - spectral radius
input_scaling = 0.5 # - input scaling (also called input gain)
connectivity = 0.1 # - recurrent weights connectivity probability
input_connectivity = 0.2 # - input weights connectivity probability
regularization = 1e-4 # - L2 regularization coefficient
transient = 100 # - number of warmup steps
seed = 1234 # - used for reproducibility
from reservoirpy.nodes import Reservoir, Ridge
reservoir = Reservoir(units, input_scaling=input_scaling, sr=spectral_radius,
lr=leak_rate, rc_connectivity=connectivity,
input_connectivity=input_connectivity, seed=seed)
readout = Ridge(ridge=regularization)
esn = reservoir >> readout
reservoir_fb = reservoir << readout
esn_fb = reservoir_fb >> readout
ESN training¶
Learning is performed offline: the model is updated only once, using all available training data.
esn.fit(X_train1, y_train1, warmup=transient);
plot_readout(readout)
ESN forecast¶
y_pred1 = esn.run(X_test1)
plot_results(y_pred1, y_test1)
Determination coefficient $R^2$ and NRMSE:
rsquare(y_test1, y_pred1), nrmse(y_test1, y_pred1)
(0.9966686152092658, 0.011448035696843583)
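The two metrics can be sketched as follows (hypothetical helper names; ReservoirPy ships its own `rsquare` and `nrmse` in `reservoirpy.observables`, whose normalization convention may differ):

```python
import numpy as np

def rsquare_sketch(y_true, y_pred):
    """Coefficient of determination: 1 - residual variance / total variance."""
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    return 1.0 - ss_res / ss_tot

def nrmse_sketch(y_true, y_pred):
    """Root mean square error, normalized by the range of the true series."""
    rmse = np.sqrt(np.mean((y_true - y_pred) ** 2))
    return rmse / (y_true.max() - y_true.min())

y_true = np.array([1.0, 2.0, 3.0, 4.0])
y_pred = np.array([1.1, 1.9, 3.2, 3.8])
print(rsquare_sketch(y_true, y_pred))  # 0.98: close to 1 for a good fit
```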
Closed-loop reservoir¶
- Train the ESN on a 1-step-ahead prediction task (learn the flow function $f(x_t) = x_{t+1}$).
- Use ESN to predict on its own activity ("generative" mode).
units = 300 # - number of units
leak_rate = 0.3 # - leaking rate
spectral_radius = 1.25 # - spectral radius
input_scaling = 0.1 # - input scaling (also called input gain)
connectivity = 0.1 # - recurrent weights connectivity probability
input_connectivity = 0.2 # - input weights connectivity probability
regularization = 1e-4 # - L2 regularization coefficient
transient = 100 # - number of warmup steps
seed = 1234 # - used for reproducibility
Forecast of close future¶
esn = reset_esn()
x, y = to_forecasting(X, forecast=1)
X_train3, y_train3 = x[:2000], y[:2000]
X_test3, y_test3 = x[2000:], y[2000:]
esn = esn.fit(X_train3, y_train3, warmup=transient)
Closed-loop model¶
- 100 timesteps used as warmup;
- 500 timesteps generated from reservoir dynamics alone, without external inputs.
seed_timesteps = 100
warming_inputs = X_test3[:seed_timesteps]
warming_out = esn.run(warming_inputs) # warmup
nb_generations = 500
X_gen = np.zeros((nb_generations, 3))
y = warming_out[-1]
for t in range(nb_generations): # generation
y = esn(y)
X_gen[t, :] = y
X_t = X_test3[seed_timesteps: nb_generations+seed_timesteps]
plot_generation(X_gen, X_t, warming_out=warming_out, warming_inputs=warming_inputs)
plot_attractors(X_gen, X_t, warming_inputs, warming_out)
Online learning ¶
Online learning is incremental: the model can be updated at any time, one timestep at a time.
In the following, we will use the Recursive Least Squares (RLS) algorithm.
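As a refresher, the core RLS recursion keeps an estimate of the inverse input-correlation matrix and corrects the weights after every sample. A minimal scalar-target sketch (helper names are illustrative; the `RLS` node applies this idea to the readout weights):

```python
import numpy as np

def rls_init(n_features, delta=1e-6):
    """Start with zero weights and a large inverse-correlation matrix."""
    return np.zeros(n_features), np.eye(n_features) / delta

def rls_step(w, P, x, d):
    """One RLS update: move w @ x toward the target d."""
    Px = P @ x
    k = Px / (1.0 + x @ Px)   # gain vector
    e = d - w @ x             # a-priori prediction error
    w = w + k * e             # weight correction
    P = P - np.outer(k, Px)   # inverse-correlation update
    return w, P

# Learn a 2D linear map online: d = 3*x0 - 2*x1
rng = np.random.default_rng(0)
w, P = rls_init(2)
for _ in range(200):
    x = rng.normal(size=2)
    w, P = rls_step(w, P, x, 3 * x[0] - 2 * x[1])
print(np.round(w, 3))  # converges to [ 3. -2.]
```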
from reservoirpy.nodes import RLS
reservoir = Reservoir(units, input_scaling=input_scaling, sr=spectral_radius,
lr=leak_rate, rc_connectivity=connectivity,
input_connectivity=input_connectivity, seed=seed)
readout = RLS() # Recursive Least Squares
esn_online = reservoir >> readout
Step-by-step training¶
outputs_pre = np.zeros(X_train1.shape)
for t, (x, y) in enumerate(zip(X_train1, y_train1)): # for each timestep:
prediction = esn_online.train(np.atleast_2d(x), np.atleast_2d(y))
outputs_pre[t, :] = prediction
plot_results(outputs_pre, y_train1, sample=100)
plot_results(outputs_pre, y_train1, sample=500)
Whole sequence training¶
esn_online.train(X_train1, y_train1)
pred_online = esn_online.run(X_test1) # Wout is now learned and fixed
plot_results(pred_online, y_test1, sample=500)
Determination coefficient $R^2$ and NRMSE:
rsquare(y_test1, pred_online), nrmse(y_test1, pred_online)
(0.9954163172865621, 0.013428448857951327)
Diving in the reservoir¶
units = 300 # - number of units
leak_rate = 0.3 # - leaking rate
spectral_radius = 1.25 # - spectral radius
input_scaling = 0.1 # - input scaling (also called input gain)
connectivity = 0.1 # - recurrent weights connectivity probability
input_connectivity = 0.2 # - input weights connectivity probability
regularization = 1e-4 # - L2 regularization coefficient
transient = 100 # - number of warmup steps
seed = 1234 # - used for reproducibility
1. The spectral radius¶
The spectral radius is the largest absolute eigenvalue of the recurrent weight matrix $W$.
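This quantity can be computed and controlled directly (sketch; ReservoirPy performs this rescaling internally when you pass `sr`):

```python
import numpy as np

rng = np.random.default_rng(42)
W = rng.normal(size=(100, 100))  # a random recurrent weight matrix

# Spectral radius: largest absolute eigenvalue of W.
rho = np.max(np.abs(np.linalg.eigvals(W)))

# Rescaling W rescales its eigenvalues linearly, so any target value works:
sr_target = 0.95
W_scaled = W * (sr_target / rho)
print(np.max(np.abs(np.linalg.eigvals(W_scaled))))  # ~0.95
```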
states = []
radii = [0.1, 1.25, 10.0]
for sr in radii:
reservoir = Reservoir(units, sr=sr, input_scaling=0.001, lr=leak_rate, rc_connectivity=connectivity,
input_connectivity=input_connectivity)
s = reservoir.run(X_test1[:500])
states.append(s)
units_nb = 20
plt.figure(figsize=(15, 8))
for i, s in enumerate(states):
plt.subplot(len(radii)*100+10+i+1)
plt.plot(s[:, :units_nb], alpha=0.6)
plt.ylabel(f"$sr={radii[i]}$")
plt.xlabel(f"Activations ({units_nb} neurons)")
plt.show()
$-$ spectral radius $\rightarrow$ stable dynamics
$+$ spectral radius $\rightarrow$ chaotic dynamics
2. The input scaling¶
It is a coefficient applied to $W_{in}$. It can be seen as a gain applied on inputs.
states = []
scalings = [0.00001, 0.001, 2.0]
for iss in scalings:
reservoir = Reservoir(units, sr=spectral_radius, input_scaling=iss, lr=leak_rate,
rc_connectivity=connectivity, input_connectivity=input_connectivity)
s = reservoir.run(X_test1[:500])
states.append(s)
units_nb = 20
plt.figure(figsize=(15, 8))
for i, s in enumerate(states):
plt.subplot(len(scalings)*100+10+i+1)
plt.plot(s[:, :units_nb], alpha=0.6)
plt.ylabel(f"$iss={scalings[i]}$")
plt.xlabel(f"Activations ({units_nb} neurons)")
plt.show()
Average correlation between reservoir states and inputs:
- $+$ input scaling $\rightarrow$ activity is bounded to input dynamics
- $-$ input scaling $\rightarrow$ activity is freely evolving
Input scaling may also be used to adjust the relative importance of different inputs.
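For instance, two channels with very different magnitudes can be rebalanced before entering the reservoir (plain NumPy pre-scaling sketch; the gain values are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
# Two input channels with very different magnitudes:
u = np.stack([rng.normal(scale=100.0, size=500),
              rng.normal(scale=0.01, size=500)], axis=1)

# Per-channel gains rebalance their influence on the reservoir:
gains = np.array([0.01, 100.0])
u_scaled = u * gains
print(u_scaled.std(axis=0))  # both channels now have comparable scale
```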
3. The leaking rate¶
$$ x(t+1) = \underbrace{\color{red}{(1 - \alpha)} x(t)}_{\text{current state}} + \underbrace{\color{red}\alpha f(u(t+1), x(t))}_{\text{new inputs}} $$
with $\alpha \in [0, 1]$ and:
$$ f(u, x) = \tanh(W_{in} \cdotp u + W \cdotp x) $$
states = []
rates = [0.02, 0.2, 0.9]
for lr in rates:
reservoir = Reservoir(units, sr=spectral_radius, input_scaling=input_scaling, lr=lr,
rc_connectivity=connectivity, input_connectivity=input_connectivity)
s = reservoir.run(X_test1[:500])
states.append(s)
units_nb = 20
plt.figure(figsize=(15, 8))
for i, s in enumerate(states):
plt.subplot(len(rates)*100+10+i+1)
plt.plot(s[:, :units_nb] + 2*i)
plt.ylabel(f"$lr={rates[i]}$")
plt.xlabel(f"States ({units_nb} neurons)")
plt.show()
- $+$ leaking rate $\rightarrow$ low inertia, short activity timescale
- $-$ leaking rate $\rightarrow$ strong inertia, long activity timescale
The leaking rate acts as the inverse of the time constant of the reservoir neurons.
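This can be checked on a single leaky unit with zero recurrent and input weights: with no input, the state decays as $(1 - \alpha)^t$ and reaches $1/e$ of its initial value after roughly $1/\alpha$ steps (sketch):

```python
import numpy as np

def decay_time(lr):
    """Steps for a leaky unit's state to fall below 1/e of its initial value."""
    x, steps = 1.0, 0
    while x > np.exp(-1):
        x *= (1 - lr)   # x[t+1] = (1 - lr) * x[t], no input
        steps += 1
    return steps

for lr in (0.02, 0.2, 0.9):
    print(lr, decay_time(lr))  # 0.02 -> 50 steps, 0.2 -> 5, 0.9 -> 1
```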
Use case: falling robot ¶
features = ['com_x', 'com_y', 'com_z', 'trunk_pitch', 'trunk_roll', 'left_x', 'left_y',
'right_x', 'right_y', 'left_ankle_pitch', 'left_ankle_roll', 'left_hip_pitch',
'left_hip_roll', 'left_hip_yaw', 'left_knee', 'right_ankle_pitch',
'right_ankle_roll', 'right_hip_pitch', 'right_hip_roll',
'right_hip_yaw', 'right_knee']
prediction = ['fallen']
force = ['force_orientation', 'force_magnitude']
plot_robot(Y, Y_train, F)
ESN training¶
Using the ESN class, an optimized and distributed implementation of Echo State Networks.
from reservoirpy.nodes import ESN
reservoir = Reservoir(300, lr=0.5, sr=0.99, input_bias=False)
readout = Ridge(ridge=1e-3)
esn = ESN(reservoir=reservoir, readout=readout, workers=-1) # enable parallel computation
esn = esn.fit(X_train, y_train)
res = esn.run(X_test)
plot_robot_results(y_test, res)
print("Mean RMSE:", f"{np.mean(scores):.4f}", "±", f"{np.std(scores):.5f}")
print("Mean RMSE (with threshold):", f"{np.mean(filt_scores):.4f}", "±", f"{np.std(filt_scores):.5f}")
Mean RMSE: 0.1693 ± 0.10344
Mean RMSE (with threshold): 0.1443 ± 0.15187
acc = 0.0
for y_pred, y_true in zip(res, y_test):
true_fall = 1 if np.max(y_true) > 0.8 else 0
pred_fall = 1 if np.max(y_pred) > 0.8 else 0
acc += true_fall == pred_fall
print("Accuracy: ", acc / len(y_test))
Accuracy: 0.997229916897507
Use case: anytime decoding of canary songs¶
Dataset can be found on Zenodo: https://zenodo.org/record/4736597
Decoded song units: phrases, which are repetitions of identical syllables.
- One label per phrase/syllable type, with phrase onset and offset time.
- One SIL label used to denote silence.
im = plt.imread("./static/canary_outputs.png")
plt.figure(figsize=(15, 15)); plt.imshow(im); plt.axis('off'); plt.show()
ESN training¶
esn = esn.fit(X_train, y_train)
outputs = esn.run(X_test)
scores # one score per song in the test set
[0.041898366578752774, 0.26824916196613413, 0.056066923530726016, ..., 0.08793074557395594, 0.25738072365190173]
(full list of per-song scores truncated for readability)
print("Mean accuracy:", f"{np.mean(scores):.4f}", "±", f"{np.std(scores):.5f}")
Mean accuracy: 0.9295 ± 0.02716
Thank you! It is now yours to try¶
Nathan Trouvain
Inria, IMN, LaBRI - Bordeaux, France
UCLA - November 14th 2023